import sklearn
import numpy as np
from sklearn.metrics.pairwise import pairwise_kernels


def _check_auto(param):
    """Return True only when *param* is exactly the string ``'auto'``.

    Anything that is not a ``str`` (numbers, lists, ``None``, ...) is reported
    as not-auto without being compared, so the result is always a plain bool.
    """
    if not isinstance(param, str):
        return False
    return param == 'auto'


class _BaseRKHSIV:
    """Base class providing hyper-parameter and kernel-matrix helpers.

    ``__init__`` accepts and ignores all arguments and sets no attributes, so
    the helpers below rely on attributes supplied by subclasses (or callers):
    ``delta_scale``, ``delta_exp``, ``alpha_scale``, ``alpha_scales``,
    ``n_alphas``, ``kernel``, ``kernel_params``, ``gamma``, ``degree``,
    ``coef0`` and ``gamma_gm``. ``delta_scale``, ``delta_exp``,
    ``alpha_scale``, ``alpha_scales`` and ``gamma_gm`` may be the string
    ``'auto'`` to select a built-in default.
    """

    def __init__(self, *args, **kwargs):
        pass

    def _get_delta(self, n: int) -> float:
        """Return delta, the critical radius, for sample size ``n``.

        Computed as ``delta_scale / n**delta_exp``; ``'auto'`` settings
        default to ``delta_scale=5`` and ``delta_exp=0.4``.
        """
        delta_scale = 5 if _check_auto(self.delta_scale) else self.delta_scale
        delta_exp = .4 if _check_auto(self.delta_exp) else self.delta_exp
        return delta_scale / (n**(delta_exp))

    def _get_alpha_scale(self):
        """Return ``alpha_scale``, or 60 when it is ``'auto'``."""
        return 60 if _check_auto(self.alpha_scale) else self.alpha_scale

    def _get_alpha_scales(self):
        """Return the candidate alpha scales.

        When ``alpha_scales`` is ``'auto'`` this is a list of ``n_alphas``
        values spaced geometrically from 0.1 to 1e5 (both endpoints
        included); otherwise ``alpha_scales`` is returned unchanged.
        """
        if _check_auto(self.alpha_scales):
            return list(np.geomspace(0.1, 1e5, self.n_alphas))
        return self.alpha_scales

    def _get_alpha(self, delta, alpha_scale):
        """Return ``alpha_scale * delta**4``."""
        return alpha_scale * (delta**4)

    def _get_kernel(self, X, Y=None):
        """Return the kernel matrix between ``X`` and ``Y`` (``X`` vs ``X`` if None).

        A callable ``kernel`` receives ``kernel_params`` (or no extra
        arguments when that is falsy). A named kernel receives
        ``gamma``/``degree``/``coef0``; ``filter_params=True`` drops whichever
        of those the chosen metric does not accept.
        """
        if callable(self.kernel):
            params = self.kernel_params or {}
        else:
            params = {"gamma": self.gamma,
                      "degree": self.degree,
                      "coef0": self.coef0}
        return pairwise_kernels(X, Y, metric=self.kernel,
                                filter_params=True, **params)

    def _get_gamma_gm(self, condition):
        """Return the RBF ``gamma`` used with ``_get_kernel_gm``.

        When ``gamma_gm`` is ``'auto'`` it is chosen by the median heuristic:
        the reciprocal of the median squared Euclidean distance over all
        distinct pairs of rows of ``condition``. Otherwise ``gamma_gm`` is
        returned unchanged.

        Raises:
            ValueError: if ``'auto'`` is requested but the heuristic is
                undefined (fewer than two rows, or a zero median distance,
                e.g. when most rows are duplicates).
        """
        if not _check_auto(self.gamma_gm):
            return self.gamma_gm
        sq_dists = sklearn.metrics.pairwise_distances(
            X=condition, metric='euclidean', n_jobs=-1, squared=True)
        return self._median_heuristic_gamma(sq_dists)

    @staticmethod
    def _median_heuristic_gamma(sq_dists):
        """Return ``1 / median`` of the strict lower triangle of ``sq_dists``.

        ``sq_dists`` is a square matrix of pairwise squared distances; the
        strict lower triangle counts every unordered pair once and excludes
        the zero diagonal.

        Raises:
            ValueError: if there are fewer than two rows or the median is not
                a finite positive number (1/median would be NaN or inf, which
                turns the RBF kernel into NaNs).
        """
        n = sq_dists.shape[0]
        if n < 2:
            raise ValueError(
                "gamma_gm='auto' needs at least two samples to compute the "
                f"median pairwise distance; got {n}.")
        median = np.median(sq_dists[np.tril_indices(n, -1)])
        if not np.isfinite(median) or median <= 0:
            raise ValueError(
                "gamma_gm='auto' is undefined: the median squared pairwise "
                f"distance is {median!r}. Pass an explicit positive gamma_gm.")
        return 1. / median

    def _get_kernel_gm(self, X, Y=None, gamma_gm=0.1):
        """Return the RBF kernel ``exp(-gamma_gm * ||x - y||**2)`` between ``X`` and ``Y``."""
        return pairwise_kernels(X, Y, metric='rbf', filter_params=True,
                                gamma=gamma_gm)

    def _get_kernel_hq(self, X, Y=None, gamma_h=0.01):
        """Return the RBF kernel ``exp(-gamma_h * ||x - y||**2)`` between ``X`` and ``Y``."""
        return pairwise_kernels(X, Y, metric='rbf', filter_params=True,
                                gamma=gamma_h)